import numpy as np

def uni(X):
    """Scale vectors to unit Euclidean (L2) length.

    X can be either a (N, d) matrix or a (d,) vector. Each row of a matrix (or
    the single vector) is divided by its L2 norm; the result has the same
    shape as X. More generally, normalization is applied along the last axis.
    A zero row divides 0 by 0 and therefore yields NaNs.
    """
    X = np.asarray(X)
    # Per-row norms directly: avoids forming the N x N Gram matrix X @ X.T,
    # whose diagonal is all the previous implementation used.
    norms = np.linalg.norm(X, axis=-1, keepdims=True)
    return X / norms

def normal(size, random_state=None):
    """Draw standard-normal samples of the given shape.

    Parameters
    ----------
    size : int or tuple of ints
        Output shape, forwarded as ``size=`` to ``normal``.
    random_state : object with a ``normal(size=...)`` method, or None
        E.g. a ``np.random.RandomState``. ``None`` draws from the global
        ``np.random`` stream.
    """
    # Identity check: ``is None`` is the correct test for the sentinel.
    rng = np.random if random_state is None else random_state
    return rng.normal(size=size)

def uni_normal(size, random_state=None):
    """Sample standard-normal vectors and rescale them to unit L2 norm.

    Draws with ``normal(size, random_state=...)`` and passes the result
    through ``uni``, so the output has shape ``size``.
    """
    return uni(normal(size, random_state=random_state))

def create_oracle(type, K):
    """Return a function that selects a length-K cascade from item utilities.

    Parameters
    ----------
    type : str
        Oracle kind; only ``'optimal'`` is implemented. (The name shadows the
        builtin ``type`` but is kept for keyword-argument compatibility.)
    K : int
        Number of items to place in the cascade.

    Raises
    ------
    ValueError
        If ``type`` is not a supported oracle kind. Previously this silently
        returned ``None``, deferring the failure to the first call site.
    """

    def optimal_cascade(utility):
        """ utility = (N,); return indices of the K largest, highest first. """
        return np.argsort(utility)[::-1][:K]

    if type == 'optimal':
        return optimal_cascade
    raise ValueError("unknown oracle type: %r" % (type,))

def first_one_index(arr):
    """Return the (flattened) index of the first element equal to 1, else -1.

    Unlike an ``argmax``-based lookup, this is correct for non-binary input
    (e.g. ``[2, 1]`` -> 1) and returns -1 for an empty input instead of
    raising.
    """
    hits = np.flatnonzero(np.asarray(arr) == 1)
    return int(hits[0]) if hits.size else -1

def convert_binary_to_onehot(arr):
    """Expand binary labels into two indicator columns.

    Stacks ``arr`` and its complement ``1 - arr`` along axis 1, so for a 1-D
    input of length n the result is (n, 2): a label 1 becomes [1, 0] and a
    label 0 becomes [0, 1].
    """
    labels = np.asarray(arr)
    complement = 1 - labels
    return np.stack((labels, complement), axis=1)

def realize(K, random_state):
    """Build a sampler of cascading-click feedback for a K-item cascade.

    Parameters
    ----------
    K : int
        Cascade length. When no item is clicked, the reported stop index is
        ``K - 1``.
    random_state : object with a ``binomial(n=..., p=...)`` method, or None
        E.g. a ``np.random.RandomState``. ``None`` uses the global
        ``np.random`` stream.

    Returns
    -------
    callable
        ``cascading(utility)``; see its docstring.
    """

    def cascading(utility):
        """Sample feedback for one cascade.

        utility = (K,) per-item scores. Item k is clicked independently with
        probability sigmoid(utility[k]) = e^u / (1 + e^u).

        Returns
        -------
        cascade_ereward : float
            Probability that at least one item is clicked,
            ``1 - prod(1 - prob)``.
        cascade_reward : int
            1 if some item was clicked, else 0.
        Ys : ndarray
            Rows ``[Y_k, 1 - Y_k]`` for items ``0..stop`` (inclusive).
        stop : int
            Index of the first clicked item, or ``K - 1`` if none was clicked.
        """
        utility = np.asarray(utility, dtype=float)
        # Logistic probabilities in log space: exp(u) overflows for u > ~709,
        # turning v / (1 + v) into inf / inf = NaN, which binomial rejects.
        prob = np.exp(-np.logaddexp(0.0, -utility))
        # 1 - prob computed directly, avoiding cancellation when prob ~ 1.
        outside = np.exp(-np.logaddexp(0.0, utility))

        cascade_ereward = 1 - np.prod(outside)

        rng = np.random if random_state is None else random_state
        Y = rng.binomial(n=1, p=prob)  # (K,)

        # Index of the base arm with the first feedback 1,
        # or -1 if every base arm got feedback 0.
        stop = first_one_index(Y)

        if stop == -1:
            stop = K - 1
            cascade_reward = 0
        else:
            cascade_reward = 1

        Ys = convert_binary_to_onehot(Y)  # (K, 2)

        return cascade_ereward, cascade_reward, Ys[:(stop + 1)], stop

    return cascading

if __name__ == "__main__":
    # No standalone entry point: this module only defines helper functions.
    pass